from ..StrategiesConstructor import StrategiesConstructor
from geesibling.core.types import Graph, Node
from .cost_graph import CostGraph
from typing import Dict, List, Tuple, Union
import numpy as np
import warnings
import time
# pulp is an optional dependency: importing this module should not hard-fail
# when it is missing, so the ILP machinery is guarded and we merely warn.
try:
    import pulp
    from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except ImportError:
    # Catch only a missing/broken pulp install; a bare `except:` would also
    # swallow SystemExit, KeyboardInterrupt, and unrelated programming errors.
    warnings.warn('please install the pulp')

INFINITY_COST = 1e13
class Solver:

    def __init__(self,
                 graph: Graph,
                 strategies_constructor: StrategiesConstructor,
                 cost_graph: CostGraph,
                #  graph_analyser: GraphAnalyser = None,
                #  memory_budget: float = -1.0,
                 solution_numbers: int = 1,
                 forward_only: bool = True):
                #  memory_increasing_coefficient: float = 1.3,
                #  verbose=False):
        '''
        Solver integrates the information provided by the other components and
        uses an ILP solver to find a possibly optimal strategy combination for
        the target computing graph.

        Arguments:
            graph: The computing graph to be optimized.
            strategies_constructor: Provides all the possible strategies for
                each node in the computing graph.
            cost_graph: A graph data structure holding per-node strategy counts
                and per-edge resharding costs (may be simplified upstream).
            solution_numbers: Currently unused — the multi-memory-budget
                solution path is disabled (see commented-out code below).
            forward_only: Stored on the instance; its effect is outside this
                class (not read anywhere in this file).

        The memory-budget machinery (graph_analyser, memory_budget,
        memory_increasing_coefficient) is intentionally disabled for now and
        kept as commented-out code.
        '''
        self.graph = graph
        self.strategies_constructor = strategies_constructor
        self.cost_graph = cost_graph
        # self.graph_analyser = graph_analyser  # would generate memory constraints
        # One strategies vector per leaf node; its order defines the node
        # indexing used throughout the solver.
        self.leaf_strategies = self.strategies_constructor.leaf_strategies
        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
        self.strategy_map = self.strategies_constructor.strategy_map
        # self.memory_budget = memory_budget
        # self.solution_numbers = solution_numbers
        self.forward_only = forward_only
        # if self.solution_numbers > 1:
        #     self.memory_increasing_coefficient = memory_increasing_coefficient
        # else:
        #     self.memory_increasing_coefficient = 1
        # temporarily we use all nodes as liveness list, we count the backward memory cost together with
        # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
        # self.liveness_list = self.graph_analyser.liveness_analysis()
        self.liveness_list = self.nodes
        # Maps node *name* (str) -> index in self.leaf_strategies.
        self.node_index_dict = self._generate_node_index_dict()

        # The last solution vector of auto sharding.
        self.last_s_val = None
        # The last objective value of the best ILP solution.
        self.last_objective = None

    def _generate_node_index_dict(self) -> Dict[str, int]:
        """Map each leaf node's *name* to its index in ``self.leaf_strategies``.

        Note: the dictionary is keyed by ``node.name`` (a string), not by the
        ``Node`` object itself — this matches the name-based lookups performed
        in ``_prepare_data_for_solver``. The previous ``Dict[Node, int]``
        return annotation was incorrect.
        """
        return {
            strategies_vector.node.name: index
            for index, strategies_vector in enumerate(self.leaf_strategies)
        }

    def call_solver_serialized_args(self):
        """Serialize the solver inputs and run the ILP solver on them.

        Thin convenience wrapper: builds the argument tuple with
        ``_prepare_data_for_solver`` and forwards it unchanged to
        ``_call_solver_serialized_args``, returning its result.
        """
        solver_args = self._prepare_data_for_solver()
        return self._call_solver_serialized_args(*solver_args)

    def _prepare_data_for_solver(self):
        """Serialize the cost graph into the flat structures the ILP expects.

        Returns a tuple ``(node_nums, strategies_len, following_nodes,
        edge_pairs, compute_costs, memory_costs, communication_costs,
        redistribute_costs)`` in exactly the layout consumed by
        ``_call_solver_serialized_args``:

        - ``strategies_len``: np.ndarray, number of candidate strategies per
          node, in ``self.leaf_strategies`` order.
        - ``following_nodes``: dict index -> index of the node it follows
          (needed when nodes were merged upstream), or -1 for none.
        - ``edge_pairs``: np.ndarray of flattened (src_index, dst_index) pairs.
        - ``redistribute_costs``: np.ndarray of per-edge resharding costs,
          serialized row-major over (src strategy i, dst strategy j).
        - per-strategy compute/memory/communication costs, flattened in the
          same leaf order.
        """
        node_nums = len(self.leaf_strategies)

        # Candidate-strategy count for every node, in leaf order.
        strategies_len = np.array(
            [self.cost_graph.node_lens[node] for node in self.nodes])

        # Translate the "following" relation from node names to node indices;
        # nodes that follow nobody are marked with -1.  (With an unsimplified
        # graph every entry ends up -1.)
        following_nodes = {
            self.node_index_dict[src]: self.node_index_dict[dst]
            for src, dst in self.cost_graph.following_dict.items()
        }
        for index in range(node_nums):
            following_nodes.setdefault(index, -1)

        # Flatten every edge into consecutive (src, dst) index pairs and
        # serialize its resharding-cost matrix row by row.
        edge_pairs = []
        redistribute_costs = []
        for (src_node, dst_node), edge_cost in self.cost_graph.edge_costs.items():
            src_index = self.node_index_dict[src_node]
            dst_index = self.node_index_dict[dst_node]
            edge_pairs.extend((src_index, dst_index))
            for i in range(strategies_len[src_index]):
                for j in range(strategies_len[dst_index]):
                    redistribute_costs.append(edge_cost[(i, j)])
        edge_pairs = np.array(edge_pairs)
        redistribute_costs = np.array(redistribute_costs)

        # Per-strategy node costs, flattened in the same leaf order.
        compute_costs = []
        memory_costs = []
        communication_costs = []
        for strategies_vector in self.leaf_strategies:
            for strategy in strategies_vector:
                communication_costs.append(strategy.communication_cost)
                compute_costs.append(strategy.compute_cost)
                memory_costs.append(strategy.memory_cost)
        compute_costs = np.array(compute_costs)
        memory_costs = np.array(memory_costs)
        communication_costs = np.array(communication_costs)

        return (node_nums, strategies_len, following_nodes, edge_pairs,
                compute_costs, memory_costs, communication_costs,
                redistribute_costs)


    def _call_solver_serialized_args(self,
                                     node_nums,
                                     strategies_len,
                                     following_nodes,
                                     edge_pairs,
                                     compute_costs,
                                     memory_costs,
                                     communication_costs,
                                     redistribute_costs):
        """
        Call the ILP solver with serialized (flattened numpy) arguments.

        Builds a binary ILP: one one-hot vector per node selects its strategy,
        one one-hot vector per edge selects a (src, dst) strategy pair; the
        objective sums per-node compute + communication cost and per-edge
        resharding cost, and linear constraints keep node and edge choices
        consistent. Memory costs are received but currently excluded from the
        objective (see the commented-out term below).

        Returns:
            (last_s_val, e_val, last_objective, status) — chosen strategy index
            per node, chosen pair index per edge, solved objective value
            (-1.0 if none) and the pulp status code.
        """

        tic = time.time()

        # Basic sanity checks on the serialized inputs.
        for x in [strategies_len, edge_pairs, compute_costs, memory_costs, communication_costs,redistribute_costs]:
            assert isinstance(x, np.ndarray)
        assert len(strategies_len) == node_nums, "strategies_len"

        def get_non_zero_index(binary_vector):
            """
            Get the index of the single non-zero item in a solved 0/1 vector
            (asserts exactly one entry is set).
            """
            ct = 0
            ret = None
            for i, elem in enumerate(binary_vector):
                if pulp.value(elem):
                    ret = i
                    ct += 1

            assert ct == 1
            return ret

        # 0. Unpack flatten numpy arrays
        # Detect and process the graph edges: reject duplicated edges, and
        # slice each edge's resharding-cost block out of the flat array using
        # the strategy counts of its two endpoint nodes.

        E = edge_pairs.reshape((-1, 2))    # e.g. (1035,2)/(1011,2) observed — 24 fewer after dropping "full" edges
        # print(E.shape)
        r = []
        pt = 0
        edge_set = set()
        for (i, j) in E:
            prod_length = strategies_len[i] * strategies_len[j]

            if (i, j) in edge_set:
                raise ValueError(f"Duplicated edges: {(i, j)}")

            edge_set.add((i, j))
            # redistribute_costs[pt:pt + prod_length] extracts this edge's
            # block of resharding costs (one per (src strategy, dst strategy)
            # pair) and appends it to r, so r holds one cost vector per edge.
            r.append(redistribute_costs[pt:pt + prod_length])
            pt += prod_length
        assert pt == len(redistribute_costs)
        # print(len(r)) #1011
        ######################
        # omit alias set now #
        ######################

        # A = alias_set.reshape((-1, 2))  # noqa
        # for (i, j) in A:
        #     prod_length = strategies_len[i] * strategies_len[j]
        #     v.append(alias_convert_costs[pt:pt + prod_length])
        #     pt += prod_length
        # assert pt == len(alias_convert_costs)

        # L = []  # noqa
        # pt = node_nums
        # for i in range(node_nums):
        #     length = liveness_set[i]
        #     L.append(liveness_set[pt:pt + length])
        #     pt += length
        # assert pt == len(liveness_set)
        v = []
        pt = 0

        # Slice the flat per-strategy cost arrays into per-node lists:
        # c = compute costs, m = memory costs, cc = communication costs.
        c = []
        m = []
        cc = []
        # qb = []
        pt = 0
        for i in range(node_nums):
            length = strategies_len[i]
            c.append(compute_costs[pt:pt + length])
            m.append(memory_costs[pt:pt + length])
            cc.append(communication_costs[pt:pt + length])
            # qb.append(qb_costs[pt:pt + length])
            # r.append(memory_costs[pt:pt + length])
            pt += length
        # assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
        assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
        # assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"

        # 1. Create variables

        #############################
        # create variables for node #
        #############################
        s_follow = following_nodes  # all -1 when the graph was not simplified (939 nodes in one observed run)
        # s_alias ??
        s = []  # per-node vector of 0/1 strategy decision variables
        num_nodes = 0
        reverse_follow_backpatch = []
        for i in range(node_nums):
            if s_follow[i] < 0: # node does not follow another node's strategy
                if strategies_len[i] == 1:  # only one strategy: choice is fixed
                    s.append([1])
                elif strategies_len[i] > 1:
                    # if i not in s_alias:
                    num_nodes += 1
                    s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))   # a vector of binary decision variables
                    # else:
                    #     s.append(s[s_alias[i]])
            else:
                if s_follow[i] < len(s):    # the followed node already has its variables
                    s.append(s[s_follow[i]])
                else:
                    s.append(None)
                    # followed node not created yet: record and patch once every node has variables
                    reverse_follow_backpatch.append(i)

        for i in reverse_follow_backpatch:
            s[i] = s[s_follow[i]]
        # print("s:   ")
        # print(s)
        #############################
        # create variables for edge #
        #############################
        e = []
        num_edges = 0
        map_edge_to_idx = {}
        for (idx, (i, j)) in enumerate(E):  # iterate over all edges
            if len(s[i]) == 1:  # i's strategy is fixed, so the pair choice reduces to j's
                e.append(s[j])
            elif len(s[j]) == 1:
                e.append(s[i])
            else:
                # if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
                #     e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
                # else:
                num_edges += 1
                e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
            # if len(e[idx]) != len(r[idx]):
            #     print(i,j,len(e[idx]) , len(r[idx]))
            assert len(e[idx]) == len(r[idx])
            map_edge_to_idx[(i, j)] = idx
        for element in s:
            assert len(element) > 0
        # 2. Set initial value
        ######################################
        # set a initial value for warm start #
        ######################################
        # if s_init_np is not None:
        #     s_init = s_init_np.reshape((-1, 3))
        #     for (idx, value, fix) in s_init:
        #         for i in range(len(s[idx])):
        #             s[idx][i].setInitialValue(i == value)
        #             if fix:
        #                 s[idx][i].fixValue()
        print(f"#nodes: {num_nodes},  #edges: {num_edges}")

        # 3. Objective
        prob = LpProblem("myProblem", LpMinimize)
        ###################################################################
        # computing the node cost(computing cost and communication cost)  #
        ###################################################################
        obj = 0
        for i in range(node_nums):
            assert len(s[i]) == len(c[i])
            assert len(s[i]) == len(m[i])

            # NOTE(review): the memory term is deliberately excluded here.
            # obj += lpDot(s[i], c[i]) + lpDot(s[i], m[i]) + lpDot(s[i], cc[i])
            obj += lpDot(s[i], c[i]) + lpDot(s[i], cc[i])

        #############################################
        # computing the edge cost(resharding cost)  #
        #############################################

        for i in range(len(E)): # edges
            assert len(e[i]) == len(r[i])
            obj += lpDot(e[i], r[i])
            # NOTE(review): hard constraint added to force every resharding
            # cost to be exactly 0 — this can make the problem infeasible when
            # no zero-resharding assignment exists; confirm this is intended.
            prob += lpDot(e[i], r[i]) == 0

        prob += obj

        # 4. Constraints
        # (a). specified by `cat="Binary"`

        # (b)
        #################################################
        # make sure each node only choose one strategy  #
        #################################################
        for i in range(node_nums):
            if s_follow[i] < 0:
                prob += lpSum(s[i]) == 1

        # (c)
        #################################################
        # compute memory consumption with liveness set  #
        #################################################
        # if memory_budget > 0:
        #     mem = 0
        #     for node in liveness_set:
        #         if node not in self.node_index_dict:
        #             continue
        #         node_index = self.node_index_dict[node]
        #         mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
        #         prob += mem <= memory_budget

        # (d). specified by `cat="Binary"`

        # (e)-(g): keep the edge one-hot vector consistent with the two node
        # one-hot vectors (a linearization of e[idx] = s[i] (x) s[j]).
        for (idx, (i, j)) in enumerate(E):
            if strategies_len[i] == 1 or strategies_len[j] == 1:
                continue

            # (e)
            prob += lpSum(e[idx]) == 1

            # (f)
            for row in range(len(s[i])):
                C = len(s[j])    # noqa
                prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]

            # (g)
            for col in range(len(s[j])):
                R = len(s[i])    # noqa
                C = len(s[j])    # noqa
                prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]

        # (h)
        ######################
        # omit alias set now #
        ######################

        # alias_set = set()
        # for (idx, (i, j)) in enumerate(A):
        #     R = len(s[i])  # noqa
        #     C = len(s[j])  # noqa
        #     if (i, j) in alias_set:
        #         raise ValueError(f"Duplicated edges: {(i, j)}")

        #     alias_set.add((i, j))
        #     alias_set.add((j, i))

        #     for row in range(len(s[i])):
        #         for col in range(len(s[j])):
        #             if v[idx][row * C + col] > 0.5:
        #                 prob += s[i][row] + s[j][col] <= 1

        msg = True
        time_limit = 600
        assert "COIN_CMD" in pulp.listSolvers(
            onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")

        solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit) # threads option removed
        # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
        prob.solve(solver)

        status = prob.status
        objective = pulp.value(prob.objective)
        objective = float(objective) if objective is not None else -1.0
        # if verbose:
        print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
                f"Time: {time.time() - tic}")
        print(f"#nodes: {num_nodes},  #edges: {num_edges}")

        if prob.status in [pulp.LpStatusInfeasible]:
            raise RuntimeError("Cannot run the function under the given memory budget. "
                               "Please increase the memory budget.")

        # Get and check results: recover the chosen strategy per node and the
        # chosen (src, dst) strategy pair per edge, then cross-check them.
        s_val = np.full((node_nums,), -1, dtype=np.int32)
        for i in range(node_nums):
            s_val[i] = get_non_zero_index(s[i])

        e_val = np.full((len(E),), -1, dtype=np.int32)
        for (idx, (i, j)) in enumerate(E):
            e_val[idx] = get_non_zero_index(e[idx])
            # e[idx] is row-major over (i strategy, j strategy) pairs.
            i_spec_index = e_val[idx] // len(s[j])
            j_spec_index = e_val[idx] % len(s[j])
            assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
            assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
            if r[idx][e_val[idx]] > 0:
                print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")

        self.last_s_val = list(s_val)
        # self._recover_merged_node_strategy()
        self.last_objective = objective

        if objective > INFINITY_COST:
            warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")

        return self.last_s_val, e_val, self.last_objective, status